{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data for Deep Ensembles as BMA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "import torch.utils.data\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "from swag import data, models, utils, losses\n",
    "from swag.posteriors import SWAG\n",
    "\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import os\n",
    "\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.manual_seed(1)\n",
    "torch.cuda.manual_seed(1)\n",
    "np.random.seed(1)\n",
    "\n",
    "import hamiltorch\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RegNet(nn.Sequential):\n",
    "    \"\"\"Fully connected ReLU regression network assembled as an ``nn.Sequential``.\n",
    "\n",
    "    Children are registered as ``linear0``, ``relu0``, ``linear1``, ...; every linear\n",
    "    layer except the last is followed by a ReLU.\n",
    "\n",
    "    Args:\n",
    "        dimensions: iterable of hidden-layer widths, in order.\n",
    "        input_dim: number of input features.\n",
    "        output_dim: number of outputs; when it equals 2 a ``SplitDim`` module named\n",
    "            ``var_split`` is appended after the last linear layer.\n",
    "        apply_var: passed to ``SplitDim`` as its ``correction`` argument.\n",
    "\n",
    "    NOTE(review): ``SplitDim`` is not defined or imported by name anywhere in this\n",
    "    notebook, so ``output_dim=2`` would raise ``NameError`` as written. Only the\n",
    "    default ``output_dim=1`` is used here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dimensions, input_dim=1, output_dim=1, apply_var=True):\n",
    "        super().__init__()\n",
    "        self.dimensions = [input_dim, *dimensions, output_dim]\n",
    "        final_layer = len(self.dimensions) - 2  # index of the last linear layer\n",
    "        widths = zip(self.dimensions[:-1], self.dimensions[1:])\n",
    "        for idx, (n_in, n_out) in enumerate(widths):\n",
    "            self.add_module('linear%d' % idx, torch.nn.Linear(n_in, n_out))\n",
    "            if idx != final_layer:\n",
    "                # no activation after the output layer\n",
    "                self.add_module('relu%d' % idx, torch.nn.ReLU())\n",
    "\n",
    "        if output_dim == 2:\n",
    "            self.add_module('var_split', SplitDim(correction=apply_var))\n",
    "\n",
    "    def forward(self, x, output_features=False):\n",
    "        \"\"\"Run ``x`` through all layers; ``output_features`` is accepted but ignored.\"\"\"\n",
    "        return super().forward(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network used to generate the toy data: 2 inputs, three hidden layers of width 10, 1 output.\n",
    "net = RegNet([10, 10, 10], input_dim=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f0a689835d0>]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2deZQU1dXAf5cZBgQRBcYNhAHBGJe4MIIkaozEDRUUiYJESVzABXc/IyHJURMTjYnGBRdcIjqjoIJKjIIL7iu4AiIGUBBUdlkVGOZ+f7wa6Wl6qZ6p7qruvr9z6kxX1atXt2pe3bp13333iapiGIZhFC5NwhbAMAzDyC6m6A3DMAocU/SGYRgFjil6wzCMAscUvWEYRoFTGrYA8bRr104rKirCFsMwDCOveO+995apanmifZFT9BUVFUybNi1sMQzDMPIKEZmfbJ+5bgzDMAocU/SGYRgFjil6wzCMAscUvWEYRoFjit4wDKPAMUVvGIZR4JiiNwzDKHBM0Rt5zezZcOONcOutMG4cvPoqfP992FIloboaKiqgSRP3t7o6bImMIiFyA6YMIx2rV8PYsfDAA/DWW1vv32EHOP10OOcc2GefnIuXmOpqGDoU1q936/Pnu3WAwYPDk8soCsyiN/ICVXj7bTjrLNhlFxg2zCn8G2+EhQthyRKYPh0mToSjjoK77oJ994UhQ7bo1lAZOXJrQdavd9sNI8v4suhF5BjgFqAEuFdVr4/b3wx4EOgOLAdOVdUvRKQMuBuoBGqBi1X15eDENwqV2lpn9L7zDkye7Javv4aWLeG005y1ftBBILLlmPJyZ8GfcAIsWwb/+hf89a/wwQcwfjx065b6fFdf7c5TW+uWn/8c/vnP+udoMAsWZLbdMAIkraIXkRJgFHAksBCYKiITVfWTmGJnAStVtauIDARuAE4FzgFQ1X1FZEfgWRE5SFVrg74QI3/54guoqoLFi2HpUvjyS2edr1nj9u+wAxx5JBx7LPTvD9ttl77Odu3gL3+BQw91L4bKSnjkEejTZ+uy333nXD3jx8Mhh7j6V62Cm2+GHj1g4MAALrJjR/fmSrTdMLKNqqZcgF7A5Jj1EcCIuDKTgV7e71JgGSC4F8TpMeVeBHqkOl/37t3VKB6++EK1QwdVUG3dWrVbN9Wf/1x1+HDVu+9Wfecd1U2bGneO+fNVDzxQtWlT1SefrL9v2TLVn/1MVUT1ppu2bK+pUa2sVN1pJ9UVKxp3flVVrapSbdHCXWjd0qKF224YAQBM02R6PNkO3aKcB+DcNXXrpwO3x5WZAXSIWZ8LtAOGAo95yr8z8C1wcoJzDAWmAdM6duyYsxtjhMvXX6t27aq6/faqH36Y3XOtXKnas6dqaanqhAmqq1er/v3vTpE3a6b66KNbH/P++6pNmqgOGxaQEFVVqp06ubdKp06m5I1ASaXos90Zez/O3TMN+BfwJrA5vpCqjlbVSlWtLC9PmE7ZiACqcPfd8P77ja9rxQrnjvn6a3j2Wdhvv8bXmYrtt3f+98pK+NWvnMfkyitdh+1rr7lt8RxwAFxyibvmRNE9GTN4sPNT1da6vxZtY+QIP4p+EbBbzHoHb1vCMiJSCrQGlqtqjapeqqr7q2o/YHvgs8aLbYTBU0/Buee6TtCLL3ZRLw1h40bo1w8++8zVefDBwcqZjNatnbI/9lg44gh49114/nl3Pcm45hrYbTcX5VNTkxs5DSNo/Cj6qUA3EensRdEMBCbGlZkIDPF+DwCmqKqKSAsRaQkgIkcCNVq/E9fIE2pqYMQI+NGP4Lzz4LbbYM894eGHnaXvF1V3/Ouvw4MPQu/e2ZM5EdttB//5j+t4TaXg69h2Wxe9M3063Htv9uUzjKyQzKej9X3ofXCW+FxgpLftWqCv97s5zhc/B3gX6OJtrwBmA7OAF4BO6c5lnbHR5J57XI/OhAlu/d13Vbt3d9t6
9lR9801/9dx0kzvmj3/MnqxBU1vrOojLy1W//TZsaQwjMTSmMzbXiyn66LFunequu6oefLBTenVs3qz673+r7rKLa0knnqj61luJ66itVX3kEde52b+/OzafeO8914d65ZVhS2IYiUml6G1krJGWW2+Fr76CG26oP3ioSRP4zW+cr/3qq+GVV6BXLzfQ6L774MMPnT/+1VddPPugQa6D88EH3bH5xIEHwhlnODfOvHlhS2MYmSGaiYM1B1RWVqpNDh4dVq6Ezp2dov7Pf1KXXbvW+bFvuskNegIoLXX+/V13hT/8waUwKCvLvtzZYNEi2GMPN+jqscfClsYw6iMi76lqZaJ9ltTMSMktt7hRotddl77sttu6cMSLLoK5c10Y5gcfOCV/zjmwzTbZlzebtG/vQjKvvhqmTvXXmWsYUcAseiMpq1a5bLqHHw5PPBG2NNFgzRr3hVNZCZMmhS2NYWwhlUWfZ55SI5fcfjt8+61zuRiOVq3gqqtcPP5rr4UtjWH4wxS9kZC1a11Sr+OOg+7dw5YmWpx/Puy8s3sBRuyD2DASYoreSMidd8Ly5fDHP4YtSfRo0cKlkX/1VXjxxbClMYz0mI/e2Ir1650fer/94LnnwpYmmmzY4CJwdt7ZTYgSSM56w2gE5qM3MuL++92MTeabT06zZu5r59137WVoRB9T9EY9amrcrEq9ernYeSM5Z5wBHTrA3/7WiEpswnAjB5iiN+rx2GMug+7vfmfuiHSUlcEVV7gRwW++2YAK6iYMnz/f9erWTRhuyt4IGPPRGz+g6lIUbNwIM2bkX5qCMFi3Djp1cl9A6UYOb0VFReLpBTt1cm9bw8gA89EbvnjuOfjoI/i//zMl75eWLd1o4Kefho8/zvBgmzDcyBH2OBs/8Pe/u3QFp50WtiT5xQUXuPQP11/vo3CsTz7Z29QmDDcCxhS9AcB778GUKXDppS6ixPDPDju4QVTjxqXJbBnvk9+81ayaLkjfT2Ihw8gAU/QG4EbBtmrl9JCRORdfDCUlLqVzUkaOdIMU4ikpcT3fnTrB6NE2l6wROKboDRYtctboWWe5qfaMzNl1Vzj1VDcGIelcusl877W1NmG4kVVM0RvccYfzIlx4YdiS5DeXXOKyW953X5ICyXzv5pM3sowvRS8ix4jIbBGZIyJXJdjfTETGefvfEZEKb3tTERkjItNFZJaIjAhWfKOxrF8Pd90FJ54IXbqELU1+0707HHKIc98kcr9z3XXOBx+L+eSNHJBW0YtICTAKOBbYCxgkInvFFTsLWKmqXYGbgRu87b8CmqnqvkB3YFjdS8CIBlVVsGKFs0aNxnPppc4D89RTCXYOHux88J06mU/eyCl+LPoewBxVnaeqG4GxQL+4Mv2AMd7vx4HeIiKAAi1FpBTYBtgIJPNgGqnIwlB5VTcH6gEHWLqDoOjXz/17/vWvJAUGD3ZvAvPJGznEj6JvD3wZs77Q25awjKrWAKuAtjilvw74GlgA/ENVV8SfQESGisg0EZm2dOnSjC+i4MnSUPnnn4dZs5wVaukOgqGkxE2l+NprbipFw4gC2e6M7QFsBnYFOgOXi8hWnmBVHa2qlapaWV5enmWR8pBEYXnr17vtjeCOO6C8HE45pVHVGHH89rfO9X7nnWFLYhgOP4p+EbBbzHoHb1vCMp6bpjWwHDgNmKSqm1R1CfAGkDAXg5GCLAyVX7jQ5WY580wbIBU0228PgwbBww+7eXcNI2z8KPqpQDcR6SwiZcBAYGJcmYnAEO/3AGCKumxpC4AjAESkJXAw8GkQghcVWQjLu+ce5wUaNqzBVRgpOO8899H14INhS2IYPhS953MfDkwGZgGPqupMEblWRPp6xe4D2orIHOAyoC4EcxSwrYjMxL0w/q2qmaZ+MgIOy9u0ySn6Y45xM0kZwdO9Oxx0kHPfRCxBrFGEWJrifKG62vnkFyyANm3cthUrnFV/3XUZRW9MmAAnn+xCAPv2TV/eaBj33+9G
G7/yChx2WNjSGIWOpSkuBOrC8h56CL77zs3cnSgCx0cY5p13wm67wXHH5fICio+BA52/3jpljbApDVsAI0PSReAMHbplf91LAH6w+P/3P3jhBfjzn10ooJE9WrSAIUNcdNPixbDTTmFLZBQrZtHnG6kicHyEYd57r1PwZ52VRRmNHzj3XNcnkrZT1uaONbKIKfp8I1UETpowzE2bYMwY57LZZZcsyWfUY8894ac/df76pN1hNneskWVM0ecbqSJw0oRhTprkXAhnnpllGY16nHkmfPopvP12kgJZGhBnGHWYos83UiXGShOGef/9sOOO0KdPCHIXMaec4v4N99+fpIDNHWtkGVP0+UiyxFgpXgKLF7sJrM84A5o2DVP44qNVK6fsx46FdesSFLA89UaWMUVfaCR5CVRVQU2Ny8Ni5J4zz4S1a2H8+AQ7LU+9kWVM0RcBqs5tcPDBsFf8TAJGTjjkEOjaNYn7xvLUG1nGFH0RMHUqfPKJdcKGiYi7/6+8AnPmJChgeeqNLGKKvgh44AHYZhs3ebURHmec4RS+JTozco0p+gJn40YYN87NCbvddmFLU9y0bw+9e7vw+IilmDIKHFP0Bc6kSS732a9/HbYkBrj/w7x58NZbYUtiFBOm6Aucqio3i9SRR4YtiQFw0knOjVZVFbYkRjFhir6AWbUKJk50WRQtdj4abLedm0B83DjnVjOMXGCKvoAZPx42bDC3TdT49a+dO23SpLAlMYoFU/QFTHU1dOvmZjoyIkJ1NUed35V2LKVq0NOWuMzICaboC5SFC+Gll5z1KBK2NAbwQ5bKpgvmMpCxTFz/S1adc0V6ZW8pjI1G4kvRi8gxIjJbROaIyFUJ9jcTkXHe/ndEpMLbPlhEPoxZakVk/2AvwUjEI4+4ED4bdxMhYrJUns5DbKA54787NnWWSkthbARAWkUvIiW4Sb6PBfYCBolI/ED6s4CVqtoVuBm4AUBVq1V1f1XdHzgd+FxVPwzyAowEVFcz9g/T6cnb7N67wpRCVIjJRnkQU9mdOYxl4JbtiSx3S2FsBIAfi74HMEdV56nqRmAs0C+uTD9gjPf7caC3yFYOg0HesUY2qa5mztnX8/7GfTmVcWYBRomYbJQCnMKjTOEIlrbfP7nlPn9+4roshbGRAX4UfXvgy5j1hd62hGVUtQZYBbSNK3Mq8EiiE4jIUBGZJiLTli5d6kfu6JMNv6qfOkeOZNz3fQEYwONum1mA0SAuS+WpjGMzpUw48s7klnuyiX0thbGRATnpjBWRnsB6VZ2RaL+qjlbVSlWtLC8vz4VI2SUbflW/dS5YwKOcws94nd1YWG+7ETJxWSp/0nEVe+y8mnFf9Ez+/9m8ObgUxoXeqVvo19cYVDXlAvQCJsesjwBGxJWZDPTyfpcCywCJ2X8z8Pt051JVunfvrnlPp06qTh3XXzp1ynqds3Y9QkH1VoYHd24ja/zxj6pNmqh+3aEy+f+3qsr9FdmynilVVaotWtSvu0WLhtUVRQr9+nwATNNkejzZjh8KOMU9D+gMlAEfAXvHlbkAuMv7PRB4NGZfE2AR0CXdubRQFL1I4odWJPkx6R5mn3Ve3f8jFTbrV+xctA0+n5gxw/2Lbjvj3ewqqmwYH1Gi0K/PB41S9O54+gCfAXOBkd62a4G+3u/mwGPAHODdWKUOHA687ec8WiiKPtNG58ca8VFnba3qj3+s+vM9v2m8BWhkH+/lvjfT9dBmb6ued172/m8NMT7yiUK/Ph80WtHncikIRZ/pZ6SfF4OPOqdPd5vvuCOrV2cEQcz/81r+oMJmXdh89+y9lAvd4i306/OBKfow8OtXrapK3EATWSNp6vxDv4+1CTW6mB3Nko86MYrpU/ZQUL2Zi7OnmArdh13o1+cDU/RRJVHjbKg1UlWlP5ZP9AheKNqGnlfEuRp+wod6CK9m/HLPiCDriiKFfn1pMEUfVZJ9bjZASX/iRduM4ryGvyyM3BH3v7+aP6mw2UXf
1GFWqpEBqRS9JTULk1Sx7aNHZ5So5omvDgbgRJ70fw4jPOIGT/VnAkoTnjry9i1lLP2BERCm6MMk2ejGTp0yzkY2oeRX9OJNduVrf+cwwiVu8NQ+HdfQdafVTFjUc0uZZC9pe3kbGWKKPkzirDqgQaMe59/yJO9t3p/+TKi/o6ysYSMojdwweDB88QXU1iLzv6D/kO2YMgVWrvT2J3tJ28s7GIpoJK0p+iDJtOHEWXV06pSxywbgiWs+BuAknqi/o1Ury1OcR/TvDzU18PTT3oaADAEjARFL/6wKF10Ed9+dtROE3wEbu+RtZ2yuO85iIgwO5RXdjw+27tAtosEihcDmzart26ueeGLMxthIkrZt3VKkUSWBErG4+1tvdae/8sqG14FF3eSAXDacmJfKN+yowma9hj9GptEaDWf4cNXmzVXXro3bYRE4wRKhkbTPPuvyHfXtq1pT0/B6Uil6c90ERS47zmKiMZ6iH0oT888XCP37w/ffJ5g43CJwgiUi/R+zZsGpp8I++0BVVfKs1I3FFH1Q5LLhxLw8JtCfrvyPvZlZv4z55/OSQw+Fdu1gQtx72yJwAiYC/R/LlsHxx0Pz5jBxontks4Up+qDIZcPxXh6racUUjuBEnmSr+b9XrAj+vEbWKS2F446DZ56BTZtidkTEAi0YAgqEaCgbNsBJJ8GiRfDkk+702cQUfVDksuF4L5XnOIpNlNGXiVuXadKkoMPF8po00Vl9+8K338Ibb8RsjIAFWnDEhLfyxRc5U/KqcPbZ8PrrMGYM9OqVk5OG3wEbu+RtZ2yuqarS08vGahuW6SZKEncsWWdd9EjUqSriUhR7rFmjWlameumlCY4t4lwuhcK117p/+5//HGy9WNRN4bFpk2pbWaanMyaxkrfIm2iSLDpLpJ7iPvZY1d13d3MMGBkS4RdidbX7d59+evD/21SK3lw3ecpb1zzHcm2b2G0Ti3XWRYtk/w/VehE0ffvC3LkuKsPIgIgNhIrltdfgt7+Fww6De+5xHt5cYYo+H0jg05146+eUsYGjmZz6WOusixap/h8xL4Hjj3d/J6Z5jxtxRDQMdfZsOPFE6NwZnngCmjXL7fl9KXoROUZEZovIHBG5KsH+ZiIyztv/johUxOz7iYi8JSIzRWS6iDQPTvwiIImFMnH14fyCl2jF2uTHWmdd9LjuuuSmXMxLoEMHOPBAU/QZE8Ew1KVLXSRVSYmLpmrTJvcypFX0IlICjAKOBfYCBonIXnHFzgJWqmpX4GbgBu/YUqAKOFdV98bNH7sJwz8JLJTZ6zvwGT/iBP6zdfkmTUIJFzN8MngwnHvu1so+wUu5b194+21YsiSD+osoUVdCIhaGun69+zpbtMi9tLt0CUWM9J2xQC9gcsz6CGBEXJnJQC/vdymwDBDcpOJV6c4Ru1hnbBwJhmr/nSsUVOc338OibPIVHzls3n/f/Vvvvz+DOos9TUKE7kFNjUtrIKL6xBPZPx+NiboBBgD3xqyfDtweV2YG0CFmfS7QDrgEeMh7EbwPXJnufKbo40gQpXEor+h+TWdGOrrA8EkKxVRbq9qhQ1ySs7pj2rbdUr5t2y1tIZ8jr4JqzxF4LmprVc8/393+227LzTnDVPRXAJ97v1sAbwG9E5xjKDANmNaxY8fc3JV8IU4RrGB7LWGT/qHfx2FLZgRBMuVcUqIqosO2rdJtm2/U77/3yldVqTZtunX5srLE9eRDFtP4F1cBfI389a/uEq64InfnTKXo/XTGLgJ2i1nv4G1LWMbzy7cGlgMLgVdVdZmqrgeeAQ6MP4GqjlbVSlWtLC8v9yFSERE34vb5dqexmVL6/G7fsCUzgiBZJ+HmzaDKcWvHsvb7prx+9Qtu+8iRcbkRPDZuTJ4RK8qRV3XBBsuXb70vAtEyDWHMGPj972HQILjhhrClcfhR9FOBbiLSWUTKgIGwVfD2RGCI93sAMMV7w0wG9hWRFt4L4OfAJ8GIXiRUV7vGvmABdOzI
f/e8nDZtoEePsAUzAiGNEj6CKZSxgWdGfe42pIoe2bw5/9IkJAqHjCXPxoFMmgRnnQW9e8MDD7g+8SiQVgxVrQGG45T2LOBRVZ0pIteKSF+v2H1AWxGZA1wGXOUduxK4Cfey+BB4X1X/G/xlFChxoZW18xfw7OutOGbPz7OWztTIMdddB02bJt3dkvUczss8s+YQtyHVi6Eu0iqkRF0NIp0ij/LXSBxTp8KAAbDvvi77aFlZ2BLFkMynE9ZinbExxPlv36VSQbWq7UVhS2YESSL/dMzyLy5SUJ07V1P76PPRn52sjyLPfPSffqrarp1qRYXqV1+FIwOWAiFPibN2nqEPQi1HL384JIGMrJAmpXQfngXg2Wdx1vm//w1t224p0LYt3H9/tC33ZCTKygnumoL6Gsny2IKvvoKjj3YfUc89B7vsEmj1gWCKPsrEfbY+Qx968g7tOrUMSSAjK6RyT4jQ7bxf0rWrG1UJOOW3bNkW23fZsvxU8pA4vXdVVXDXlOXcNytXOiW/fLl7EXfrFki1gWOKPkzSWRox1s4SypnKQfRp+ny0O9eMzEll1T70ENxxB336wJQp8N13uRcv62QzL3wWc9+sW+dSG3z2mZs8pHv3RleZPZL5dMJaisZH73cEnzf440FOV1Cd9udnwpHXyC5pBvlMmuSayDP2789sQFSWJgHfsEH16KPdpN7jxzeqqsDA8tFHkAxHMQ4cqLrTTqqbN+dUSiMifPedswOGD0+wMwIjQXNGpikOsjBauKZG9ZRTXDX33dfgagLHFH0UycDSqKlRbdNGdciQ3ItpRIfjj3eTkdQjQrldckKmijvg+1Nbq3r22a6af/yjwVeRFVIpevPRh0UGWfamTXOBGUcf7W0o9gyFRcrRR7vJSObO9TZUV8OQIZHMv541ksXdz5+f+DkIcC5nVbj8crj3Xjfy9fLLM64iPJK9AcJaisaiz8DSuPZaZ+gvXZrZcUaeE+eSmX3jUwqqd9yhidtBPuW3aSghxt1ffbU7zYUXRnOKR8x1E1F8+lZ/9jPVykpvJd8zFBr+SKDIa7dpoRXla7RfP02t8Aq5PaR7wWXpuv/5T1f9b34T3X4yU/T5iPcS+JbWWsIm/X3f6W57lqIIjIiRRJEP3bZaW7VS3UiC0bH5/IWXSYdyVVXya8/CczBqlKt6wADVTZsCrz4wTNHnGzFWywROVFB9pdmRhZFz3PBHEkU2nv4Kqq/uNCBxmZKS/FTymbojc/Qc3Hefq/aEE1Q3bgy06sAxRZ9vxDTiYdyp27JaN9B0i6VjPvrCpqoq6Zfbyt321ZIS1ZH9phdOO2iI0s7Bc1Bd7f4NRx3lwlujjin6KNCAQR61oBXM0748qfU+S4spbroYSab4RFSrqvSnP1U96CAtnHbQUHdkFq9/7Fg3GOrww1XXrQus2qxiij5sGjjI4zO6KqiO4rz0Fo5ROCRTfKCqqtdcExeFFa/s8u0FkGt3ZJr78/jjzgN2yCGqa9ZkR4RsYIo+bBo4yOM2LlBQ/R+75+9nuZE5adrLW2+51UcueG1rA6Jp062nFYx62/FjCAU5n2yKcz35pGppqWqvXqqrVzf6ynKKKfqwacinaVWVnrDNc9qFOflhlRnBkUYZ1dSo7rCD6m9bjktu+edbZ30qRR6kPz7FS/TJJ917skcP1W+/Dei6cogp+rBJMwF0IkU+b55q8+ZuJnmjCEljwQ4YoNqeL7XWr6LP5/DbIF07SYyuJ+mX10peNbWitxQIQZMoPUGyNLTeBNDxObJVYdgwKC2Fq67KqfRGVEiTuvfII2ERHZjNj/zVl0dT8m1FsrQHibbHP3/nn19/vU2brQ55ir78ikc54AA3cUjr1gHKHhWSvQFiF+AYYDYwB7gqwf5mwDhv/ztAhbe9AvgON1/sh8Bd6c6V1xZ9sk/M886rP11ckyYpLZQxY9zq7beHezlGdJk717WR25peWr8N5aOPPpZEXzJ+Lfp0o2YT3J/HOFlL
2ag9d1+at5Z8HTTGdQOUAHOBLkAZ8BGwV1yZ8+uUODAQGKdbFP2MdOeIXfJa0acKi/P5eb14sctU+dOfRneotRENOndW7dd9Qf5H3dSRylDy46NPlxaibmnbVrVTJ32YQVrCJv1ptyW6alUoVxwojVX0vYDJMesjgBFxZSYDvbzfpcAyQIpO0ftV6HHLJkp0Dl30mR2H6LHHOoPjk0/Cvhgj6pxzjmrr1tEelp8RqSz3ZC+v2O0ZPHMPnvuGNmmiethh+RVCmYrGKvoBwL0x66cDt8eVmQF0iFmfC7TzFP064APgFeDQJOcYCkwDpnXs2DFX9yV4fFgUtaAf8hO9mj/pyTymezNdy/i+XrEbbwz7QozIkUDRjfOCbt5+O2zhAiLT6DQ/rpoEy92co8Jm7b33V7p2bW4vMZukUvTZ7oz9GuioqgcAlwEPi8h28YVUdbSqVqpqZXl5eZZFyiKJOl1FAFhNK/7CSPbiE/bnI67lT0xnX3ZnLpfIrdx3zlu89hosXgxXXBGC7EZ0STLB9RErHgfghRfSHJsvcxdkMEcDkHg+2DT8i4sZxmiO5Vn+s/pwWrbMUMZ8JdkboG6hEa6bBHW9DFSmOl9eu25Ut7a8zjtPnyg7RdvzpYLqYbysdzJMlxLTOZvPoW9G9ontyI9zaRxwgBumn5B8y4uUqbzp3DUx+2tB/8LvFVRP5jGXO6rAnjsa6bopBeYBndnSGbt3XJkLqN8Z+6j3uxwo8X53ARYBbVKdL+8VfQxffaV60knuLv+k6Sf6DgclfWBVNX870YzskSYl7//9n+vTSeiCyMdMp5k8A+muz6urFvRKrldQ/TUP6iZKon8fGkCjFL07nj7AZzjf+0hv27VAX+93c+AxXHjlu0AXb/vJwExcaOX7wAnpztUYRf/NNw0+NFBqa12I5Pbbu0FP11/vpThNZbHkm/Vl5IZU/T6dOunkye7npEkJji30uQt8PDM1NarDjvhMQfU8RulmpGCfrUYr+lwuDVX0b7yhus024c/KvmiR6nHHuTt7yCGqs2fHFUhmseSj9WVkn1TuiaoqXbfOWfRXXJHg2GJoUym+ADZsUB00yF3yiBOma23HxOUKhaJQ9GvWqB59tLui668PZ07HRx5xOafW0Q8AABMvSURBVEi22Ub1llsyjIMvdOvLaBjJlHXbtj8U+cUvVPfeO0GbL5SvxAa4NNeuVT3mmC36oKH15BNFoehV67/BL7ssdwOOli9XPfVUd96DD1b97LO4An4aWDFYX0bm+FDWdbMgJdRb+a7cEl1/nVGU5HqWL3fZJ5s0Ub3nnhT15ONLLwVFo+hVnXK/8EJ3Zaed5pR/NnnpJdUOHVxq07/8JcHglWSxvm3bZi9Dn1FYpFHWmze7iUh23lkLYoRnPdKNTYl7Rr780n3dlJWpjh/vo56YL6N8p6gUvar7hL3uOnd1vXtnp/Fv3Kg6YoR79vbYQ3XatCQFUzXUbOXcNoqOd991zebSS8OWJGD8jHj1vnpnznRGV6tWqi++mEE98UZXnlJ0ir6OBx5wlvZ++7lO0qD4/HPVnj3d3Tv77CShbXWka6jmmjECYuhQl/l6+vSwJQkQP/lrRPT1113/2M47q37wQQPqKYCv51SKvqDTFA8ZAk8/DXPmwMEHw4wZja/zySfhgANg1ix49FG45x5Sj65Llx42WQpWw8iQv/7Vpdg9+2xYtChsaQIiWYrvGJ5odw6//CW0awdvvgn775+knlSsX+9G2hYqyd4AYS3ZGDD13nuqu+yiut12qi+80LA6Nm5UveQS9/KvrHRpYn2RLh9HupnuzZVjZMDYsc4/3aKFm1s2Xya2TklsquK4L+Rbml6uIrV68MGqS5akqSfZCOMCiXCjWF03scyf7zppSktV//3vzI5duNClDQbViy5S/f77DE9eVZW4kaX6XLTOWaOBzJ3rZqAC1Y4dVV9+OWyJAsRT+ptpope1Gq2geuKJPl9ojTG68gBT9B7ffqv6
y1+6q/797/2FX06Zolpertqypeq4cY0UIMjh3YaRhldeUe3WzTW3kSO90dn5QJrnZN061f793eNw4YVu9GtGdWdqdOUJpuhj2LjR5fEGZ/WsX5+4XG2t6j//6WJxf/zjEPLD2wAqIwDWrFE980zXdHr2dHMRR5o0cfPfjHpce/Rwm26+uREDIwvQLWqKPo7aWtV//MP9jysrt47IWbfOxeCD6sknq65enXWRtsYseiNAHn3UTVLSurXqhAlhS5OCFNExM9hLK5inLVirT176cv3jClBxZ4op+iQ89ZRzyey6q4tDXrdOddQoN0WbiIvFDyOVgqqaj94InHnz3MCqOpdHxn1NuSDJl+wzHKOtWKW7sEin0r3+s2DPiqqaok/Jxx+rVlS4LJPt2ukPn7gNjc7xhV/rw6wUI2A2bNgSPdazp+qCBWFLFEecRV8LegsXahNqdH/e1y9pv2V/SYl7NkpKEn8FeHPDFsvzY4o+DUuWqPbpo3r88a4DK6tWvFkfRlA0whAYP96NIG3XTvX557MmYebEPB8baKrncLeLrGGCriXzaQMTPmcFakCZoo8S5ns3giAAg+HTT13IcZMmIWR8TaVsq6p0cYcD9RBedRFy/GVLHvnGLm3bFqyhZYo+SviNpilQq8MIiIAMhrVrt2RePeWUNOk8giLNS+qDD1z8f/Pmqo9c8Fr6gU5BLAVgaJmijxJ+HlBz7xjp8GMw+DQWamtVb7jBFfvJT1wup6yS4hmornbzOXToEJcoMPZakvnkG7MUQNiyKfoo4UeJm3vHSIef+VL9pMeO4dln3fSX5eWqr72WRdkTvKQ2UaKX8w8F1cMOSzMtaLoRrqlcNKm+DvL8y7nRih44BpiNmxP2qgT7mwHjvP3vABVx+zsCa4Er0p2r4BW9anpLywZLGelIZzBkkh47hk8/daNpy8pc9tesECfbYsr1F7yooDp8uM8RvCny36TsdE33ksjjL+dGKXqgBDcpeBegDPgI2CuuzPnAXd7vgcC4uP2P4yYPN0XvB7PoDT+kMhgakR57xQo3jwOo/u53WZipLUbZvkVPbc+X2pz1+sDQNxpeXyb9WbEviQJ6zhqr6HsBk2PWRwAj4spMBnp5v0uBZYB46ycCNwJXm6L3ifnojcaSLv96mq/DjRtVhw1zRU85JXmqkIZS+1CV3t7mD9qUDdq5dL5+cN1/gz2BHwrsyzmVoveTj7498GXM+kJvW8IyqloDrALaisi2wO+Aa1KdQESGisg0EZm2dOlSHyIVOIMHw+jR0KkTiLi/o0e77Ybhh3R53NPMk9C0Kdx5J/z9727ehd69IahHc80aGPT0YIav+DNHHVfGe0s6sv/v+wRTeR3V1VBRAU2auL/V1VuXSXYP0s0hkY8kewPULcAA4N6Y9dOB2+PKzAA6xKzPBdoB/wBO8bZdjVn0hpE7AsrU+NhjLtSxa1fVOXMaJ9LHH6v+6Ecudv9vf8uCW0jV/xdxgX05E5brBngN+MJbvgVWAMNTnc8UvWEEQKzfum1btzRiTMYbb6i2aaO6446qU6dmLk5trero0e6FsfPOqi+9lHkdvsmkj6uAxqs0VtGXAvOAzmzpjN07rswF1O+MfTRBPWbRG0YuqKpyYTOxSq6srNFKbNYspwtbtlSdNMn/catWqQ4c6MQ48sg0oZNBUGC+d7+kUvRpffTqfO7DPat9lqfEZ4rItSLS1yt2H84nPwe4DLgqXb1FgR8/YSblDMNPW7n4Yti4sf62jRvd9kaw557w1lvQtSscfzw88kj6Y95+283h+thjbk7bSZNgp50aJUZ6isn37pdkb4CwloKx6BvjJ4yZaCGfPyWNgPHbptJF2zSyXX37rRvUJKJ6++2Jy9TUuDTfJSXudK+/3uDTZU6B+d79go2MDQG/fsJ0YXBF0EANn/htU6naU0Dtav161b59XVWXXlp/ztb581UPP9ztO/VU1ZUrG3yahlNAvne/pFL0dbHukaGyslKnTZsW
thiNp0kT90jFIwK1tenLxdKpE3zxRaDiGXmI3zbVrh0sX56+vka2q5oa5w264w7YfXcXAfzNN3D++bB5M9x2GwwZ4sQzso+IvKeqlYn2+YmjNxqCXz+hH7/hggWNl8fIf9K1qTr/vR8lD41uV6WlMGoUTJnilHnv3m6oxz77wEcfwW9+Y0o+KpiizxaJBqy0aOG2pysXTzF3IhlbSNWmqqth6FCYP3/LvjotW1KSuL6A2tUvfgEffwzXXAM33ggvvwxdugRStREUyXw6YS0F46NXzXzKwNiOWPPRG4lI1qZS+e+LtHOy2MB89HlEdTWMHOk+qzt2dNaapT4w0pHMR1Lnv7d2VfCk8tGbojeMfKe6Gk4/PXFHrXXkFw3WGZtrbACUkUtGjkwejRPfJxS1thk1eQqU0rAFKDjqOsXWr3fr8+e7dbBPZSM7JIueUa3f5qLWNqMmTwFjrpugqaioH/lQh31CG9nCb5uLWtuMmjx5jrluckky68pi4Y1s4TeUN2ptM2ryFDCm6IPGEioZucbvRDVRa5tRk6eAMUUfNH6tK8MIksGDnbujttb9TeTjjlrbTCePddQGR7IA+7CWghgw1diESkWYkMnIEVFrW8nksUFeGYMNmMoj4iMRwFk5NmesUUxYR23GWGdsPjFyZH0lD2595Mhw5DGMMLCO2kAxRR81rIEbxUqsT75JEtUU31Frfnxf2ICpqNGxY+JPVotEMAqZeJfl5s1bl4nvOLYBV77xZdGLyDEiMltE5ojIVvPBikgzERnn7X9HRCq87T1E5ENv+UhETgpW/AIkapERhpELErkswaVYThYyam5O36S16EWkBBgFHAksBKaKyERV/SSm2FnASlXtKiIDgRuAU4EZQKWq1ojILsBHIvIfdROOG4moa8iWadAoJpK5Jmtr68+e5ecYc3NuhR+LvgcwR1XnqepGYCzQL65MP2CM9/txoLeIiKquj1HqzYFohfhEFT8x0YZRSDRk8JQNuPKNH0XfHvgyZn2hty1hGU+xrwLaAohITxGZCUwHzk1kzYvIUBGZJiLTli5dmvlVGIaR3zTEZWluTt9kPepGVd9R1b2Bg4ARItI8QZnRqlqpqpXl5eXZFskwipeoRqn4TePQ2GOKFD9RN4uA3WLWO3jbEpVZKCKlQGug3gzFqjpLRNYC+wBFPCLKMEIi6lEqgwdnLkdDjilC/Fj0U4FuItJZRMqAgcDEuDITgSHe7wHAFFVV75hSABHpBOwJfBGI5IZhZIZFqRQtaRW951MfDkwGZgGPqupMEblWRPp6xe4D2orIHOAyoC4E8xBcpM2HwBPA+aq6LOiLKCqi+ultRB+LUilaLNdNPmF5cIzGYPljChrLdVMo2Ke30RgsSqVoMUUfNpm4YuzT22gMFqVStFiumzDJNArC8uAYjcWiVIoSs+izTSqLPVNXjH16G4bRAEzRZ5M6i33+fDdHTp3FXqfsM3XF2Ke3YRgNwKJuskm6KAeLgjAMIyAs6iYs0lns5ooxDCMHmKLPJumy65krxjCMHGCKPpv4sdgtJbFhGFnGFH02MYvdMIwIYHH02cbilg3DCBmz6A3DMAocU/SGYRgFjil6wzCMAscUvWEYRoFjit4wDKPAMUVvGIZR4PhS9CJyjIjMFpE5InJVgv3NRGSct/8dEanwth8pIu+JyHTv7xHBim8YhmGkI62iF5ESYBRwLLAXMEhE9oordhawUlW7AjcDN3jblwEnqOq+uMnDHwpKcMMwigybL7nB+LHoewBzVHWeqm4ExgL94sr0A8Z4vx8HeouIqOoHqvqVt30msI2INAtCcMMwioh0Kb+NlPhR9O2BL2PWF3rbEpZR1RpgFdA2rszJwPuquiH+BCIyVESmici0pUuX+pXdMIxiweZLbhQ56YwVkb1x7pxhifar6mhVrVTVyvLy8lyIZBhGlIl30ySatwFsvmSf+Ml1swjYLWa9g7ctUZmFIlIKtAaWA4hIB+AJ4AxVndto
iQ3DKGwSzaUs4lw28dh8yb7wY9FPBbqJSGcRKQMGAhPjykzEdbYCDACmqKqKyPbAf4GrVPWNoIQ2DKMASNa5mshNo+qUfSw2SY9v0lr0qlojIsOByUAJcL+qzhSRa4FpqjoRuA94SETmACtwLwOA4UBX4E8i8idv21GquiToCzEMI49IZLUPHep+J3PHqLpU3wsWOEv+uussM6xPbM5YwzByT6r5ksHmUm4ANmesYRjRItV8yjaXcuCYojcMI/ekmk/ZZmYLHFP0hmHknnRWu82lHCim6A3DyD1mtecUmzPWMIxwsPmUc4ZZ9IZhGAWOKXrDMIwCxxS9YRhGgWOK3jAMo8AxRW8YhlHgRC4FgogsBZLkJKUdbtaqqJMvckL+yGpyBku+yAn5I2vYcnZS1YR53iOn6FMhItOS5XKIEvkiJ+SPrCZnsOSLnJA/skZZTnPdGIZhFDim6A3DMAqcfFP0o8MWwCf5Iifkj6wmZ7Dki5yQP7JGVs688tEbhmEYmZNvFr1hGIaRIaboDcMwCpy8UPQi8isRmSkitSJSGbO9QkS+E5EPveWuKMrp7RshInNEZLaIHB2WjPGIyNUisijmHvYJW6ZYROQY757NEZGrwpYnFSLyhYhM9+5jZObDFJH7RWSJiMyI2dZGRJ4Xkf95f3cIU0ZPpkRyRq59ishuIvKSiHziPe8Xe9sjd0/ryAtFD8wA+gOvJtg3V1X395ZzcyxXPAnlFJG9cBOm7w0cA9whIiW5Fy8pN8fcw2fCFqYO7x6NAo4F9gIGefcyyvzCu49Riqd+ANfuYrkKeFFVuwEveuth8wBbywnRa581wOWquhdwMHCB1y6jeE+BPFH0qjpLVWeHLUc6UsjZDxirqhtU9XNgDtAjt9LlJT2AOao6T1U3AmNx99LIAFV9FVgRt7kfMMb7PQY4MadCJSCJnJFDVb9W1fe932uAWUB7InhP68gLRZ+GziLygYi8IiKHhi1MEtoDX8asL/S2RYXhIvKx9+kcmc9Non/f4lHgORF5T0SGhi1MGnZS1a+9398AO4UpTBqi2j4RkQrgAOAdInxPI6PoReQFEZmRYEllwX0NdFTVA4DLgIdFZLsIyhkqaWS+E9gd2B93P/8ZqrD5zSGqeiDO1XSBiBwWtkB+UBdjHdU468i2TxHZFhgPXKKqq2P3Re2eRmYqQVX9ZQOO2QBs8H6/JyJzgT2ArHWENUROYBGwW8x6B29bTvArs4jcAzydZXEyIdT7limqusj7u0REnsC5nhL1K0WBxSKyi6p+LSK7AEvCFigRqrq47neU2qeINMUp+WpVneBtjuw9jYxF3xBEpLyuU1NEugDdgHnhSpWQicBAEWkmIp1xcr4bskwAeA2yjpNwHcpRYSrQTUQ6i0gZrkN7YsgyJUREWopIq7rfwFFE617GMxEY4v0eAjwVoixJiWL7FBEB7gNmqepNMbuie09VNfIL7h+8EGe9LwYme9tPBmYCHwLvAydEUU5v30hgLjAbODbsexoj10PAdOBjXEPdJWyZ4uTrA3zm3buRYcuTQs4uwEfeMjNKsgKP4Nwem7z2eRbQFhcZ8j/gBaBNROWMXPsEDsG5ZT72dM+HXjuN3D2tWywFgmEYRoGT164bwzAMIz2m6A3DMAocU/SGYRgFjil6wzCMAscUvWEYRoFjit4wDKPAMUVvGIZR4Pw/xL5LMsrWiHQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Overwrite every weight and bias of net in place with draws from a zero-mean normal, std 0.1.\n",
    "for p in net.parameters():\n",
    "    nn.init.normal_(p, std=0.1)\n",
    "\n",
    "def featurize(x):\n",
    "    \"\"\"Return x and its elementwise square stacked along a new dimension 1.\n",
    "\n",
    "    For a 1-D tensor of length N the result has shape (N, 2) with columns [x, x**2].\n",
    "    \"\"\"\n",
    "    return torch.stack([x, x ** 2], dim=1)\n",
    "    \n",
    "# Training inputs: three separated clusters of 40 points each.\n",
    "x = torch.cat([torch.linspace(-10, -6, 40), torch.linspace(6, 10, 40), torch.linspace(14, 18, 40)])\n",
    "# Two-cluster variant:\n",
    "# x = torch.cat([torch.linspace(-10, -6, 20), torch.linspace(6, 10, 20)])\n",
    "f = featurize(x)\n",
    "# Dense grid reaching past the outer clusters, for drawing the noise-free function.\n",
    "x_ = torch.linspace(-14, 22, 100)\n",
    "f_ = featurize(x_)\n",
    "\n",
    "# net is only evaluated as a fixed function here, so skip building an autograd graph;\n",
    "# the outputs are then plain tensors and need no .data before .numpy().\n",
    "with torch.no_grad():\n",
    "    y = net(f)\n",
    "    y_ = net(f_)\n",
    "\n",
    "# Gaussian observation noise (std 0.01) on the training targets only.\n",
    "y = y + torch.randn_like(y) * .01\n",
    "\n",
    "plt.plot(x.numpy(), y.numpy(), \"ro\", label=\"noisy samples\")\n",
    "plt.plot(x_.numpy(), y_.numpy(), \"-b\", label=\"noise-free function\")\n",
    "plt.xlabel(\"x\")\n",
    "plt.ylabel(\"y\")\n",
    "# trailing semicolon keeps the Legend repr out of the cell output\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.savez(\"data\", x=x.data.numpy(), y=y.data.numpy(), x_=x_.data.numpy(), y_=y_.data.numpy())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
